package it.uniroma2.exp.tools;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.Closeable;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Vector;

/**
 * One-vs-all multi-class wrapper around the {@code svm_learn} / {@code svm_classify}
 * executables found in a given binary directory.
 *
 * <p>Usage: {@code MultiClassSVMWrapperForCDS <bin_dir> <command> <file> [<arg>]}, where
 * {@code <command>} is one of:
 * <ul>
 * <li>{@code prepare <testing_set>}: split {@code <file>.svm} into one 1 / -1 labelled
 *     training set per class, and write the learner / classifier scripts plus the
 *     {@code <file>_classes.txt} class list;</li>
 * <li>{@code class <classes_file>}: merge the per-class prediction files
 *     ({@code <file>_<class>.out}) into {@code <file>.out};</li>
 * <li>{@code perf}: print the confusion matrix and accuracy of {@code <file>.out} against
 *     the labels in {@code <file>.svm};</li>
 * <li>{@code analyze <classes_file>}: {@code class} followed by {@code perf}.</li>
 * </ul>
 *
 * <p>Instance lines consist of a class label, a separator and the instance description.
 * The separator is the first tab; a line without any tab falls back to its first
 * whitespace character.
 */
public class MultiClassSVMWrapperForCDS {

	public enum OS { LINUX, WIN; }

	private static final String USAGE =
		"Usage: MultiClassSVMWrapperForCDS <bin_dir> <command> <file> [<arg>]\n"
		+ "  commands: prepare <testing_set> | class <classes_file> | analyze <classes_file> | perf";

	/**
	 * Command-line entry point; see the class documentation for the accepted commands.
	 *
	 * @param argv {@code <bin_dir> <command> <file> [<arg>]}
	 * @throws IllegalArgumentException if arguments are missing or the command is unknown
	 * @throws Exception on I/O or parse failures in the individual steps
	 */
	public static void main(String [] argv) throws Exception {
		if (argv.length < 3) {
			throw new IllegalArgumentException("Expected at least 3 arguments.\n" + USAGE);
		}
		String bin_dir = argv[0];
		String command = argv[1];
		String file_to_treat = argv[2];
		if (command.equals("prepare")) {
			String testing_set = requiredArgument(argv, 3, "testing_set");
			// Read the class list once; both steps only read it.
			Vector<String> classes = getClassesFromInstanceFile(file_to_treat);
			generateDifferentTrainingSets(file_to_treat, classes);
			generateLearnersAndClassifierCommand(bin_dir, file_to_treat, testing_set, classes);
		} else if (command.equals("analyze")) {
			generateFinalDecisions(file_to_treat, requiredArgument(argv, 3, "classes_file"));
			printPerformances(file_to_treat);
		} else if (command.equals("class")) {
			generateFinalDecisions(file_to_treat, requiredArgument(argv, 3, "classes_file"));
		} else if (command.equals("perf")) {
			printPerformances(file_to_treat);
		} else {
			throw new IllegalArgumentException("Unknown command: " + command + "\n" + USAGE);
		}
	}

	/** Returns {@code argv[index]}, failing with a usage message when it is absent. */
	private static String requiredArgument(String [] argv, int index, String name) {
		if (argv.length <= index) {
			throw new IllegalArgumentException(
				"Missing argument <" + name + "> for command " + argv[1] + ".\n" + USAGE);
		}
		return argv[index];
	}

	/** Compares {@code <file>.out} decisions against the labels in {@code <file>.svm}. */
	private static void printPerformances(String file_to_treat) throws Exception {
		computePerformances(
			loadClassifiedElements(file_to_treat + ".svm"),
			loadClassifiedElements(file_to_treat + ".out"),
			getClassesFromInstanceFile(file_to_treat));
	}

	/**
	 * Collects the distinct class labels of {@code <instance_set>.svm}, in order of first
	 * appearance. Reading stops at end of file or at the first blank line.
	 *
	 * @param instance_set path of the instance file without the {@code .svm} extension
	 * @return the distinct labels, first-seen order
	 */
	public static Vector<String> getClassesFromInstanceFile(String instance_set) throws Exception {
		Vector<String> classes = new Vector<String>();
		BufferedReader in = new BufferedReader(new FileReader(instance_set + ".svm"));
		try {
			for (String line = in.readLine(); line != null && !isBlank(line); line = in.readLine()) {
				String label = labelOf(line);
				if (!classes.contains(label)) classes.add(label);
			}
		} finally {
			in.close();
		}
		return classes;
	}

	/**
	 * Writes one binary training file {@code <training_set>_<class>.svm} per class. In each
	 * file an instance is labelled {@code 1} if its label equals that class and {@code -1}
	 * otherwise. Reading stops at end of file or at the first blank line.
	 *
	 * @param training_set path of the training file without the {@code .svm} extension
	 * @param classes      the classes to build one-vs-all sets for
	 */
	public static void generateDifferentTrainingSets(String training_set, Vector<String> classes) throws Exception {
		// Open the input first so a missing source file does not leave empty per-class files behind.
		BufferedReader in = new BufferedReader(new FileReader(training_set + ".svm"));
		BufferedWriter [] out = new BufferedWriter [classes.size()];
		try {
			for (int i = 0; i < out.length; i++) {
				out[i] = new BufferedWriter(new FileWriter(training_set + "_" + classes.get(i) + ".svm"));
			}
			for (String line = in.readLine(); line != null && !isBlank(line); line = in.readLine()) {
				String c = labelOf(line);
				String instance = instanceOf(line);
				for (int i = 0; i < out.length; i++) {
					out[i].write((classes.get(i).equals(c) ? "1\t" : "-1\t") + instance + "\n");
				}
			}
		} finally {
			try {
				closeAll(out);
			} finally {
				in.close();
			}
		}
	}

	/**
	 * Writes three files. Two are scripts in the current working directory:
	 * {@code svm_multi_learner} trains one model per class, and {@code svm_multi_classifier}
	 * runs every model on {@code <testing_set>.svm}. Both get a {@code .bat} suffix and batch
	 * syntax when the platform file separator is not {@code '/'}. The third is
	 * {@code <training_set>_classes.txt}, holding the class names, each followed by {@code ':'}.
	 *
	 * <p>The learner script forwards its own command-line arguments to {@code svm_learn}
	 * via the {@code SVM_PARAMETERS} variable. NOTE(review): class names containing
	 * {@code ':'} cannot round-trip through the classes file.
	 *
	 * @param bin_dir      directory containing {@code svm_learn} and {@code svm_classify}
	 * @param training_set training file path without the {@code .svm} extension
	 * @param testing_set  testing file path without the {@code .svm} extension
	 * @param classes      the classes to generate commands for
	 */
	public static void generateLearnersAndClassifierCommand(String bin_dir, String training_set, String testing_set, Vector<String> classes) throws Exception {
		String svm_learn = new File(bin_dir, "svm_learn").getPath();
		String svm_classify = new File(bin_dir, "svm_classify").getPath();
		OS os = (File.separatorChar == '/' ? OS.LINUX : OS.WIN);
		boolean win = (os == OS.WIN);

		String script_suffix = (win ? ".bat" : "");
		String header = (win ? "@echo off" : "#!/bin/bash") + "\n\n";
		// Variable reference syntax: %NAME% on Windows, $NAME in bash.
		String p_left = (win ? "%" : "$");
		String p_right = (win ? "%" : "");
		String set = (win ? "SET " : "");

		StringBuilder learner = new StringBuilder(header);
		learner.append(set).append("SVM_PARAMETERS=").append(p_left).append("*\n\n");
		for (String c : classes) {
			learner.append("echo ---------------\necho Learning model for ").append(c)
				.append(" class\necho ---------------\n");
			learner.append(svm_learn).append(' ')
				.append(p_left).append("SVM_PARAMETERS").append(p_right).append(' ')
				.append(training_set).append('_').append(c).append(".svm ")
				.append(training_set).append('_').append(c).append("_model.svm\n");
		}
		writeTextFile(new File("svm_multi_learner" + script_suffix), learner.toString());

		StringBuilder classifier = new StringBuilder(header);
		for (String c : classes) {
			classifier.append("echo ---------------\necho Classifying for ").append(c)
				.append(" class\necho ---------------\n");
			classifier.append(svm_classify).append(' ')
				.append(testing_set).append(".svm ")
				.append(training_set).append('_').append(c).append("_model.svm ")
				.append(testing_set).append('_').append(c).append(".out\n");
		}
		writeTextFile(new File("svm_multi_classifier" + script_suffix), classifier.toString());

		StringBuilder names = new StringBuilder();
		for (String c : classes) names.append(c).append(':');
		names.append('\n');
		File training_file = new File(training_set);
		writeTextFile(new File(training_file.getParentFile(), training_file.getName() + "_classes.txt"), names.toString());
	}

	/**
	 * Merges per-class prediction files into {@code <testing_set>.out}. For each line of
	 * {@code <testing_set>.svm}, one value is read from every
	 * {@code <testing_set>_<class>.out}. The written line is the class with the highest value,
	 * a tab, and that value.
	 *
	 * @param testing_set  testing file path without the {@code .svm} extension
	 * @param classes_file file whose first line lists the class names separated by {@code ':'}
	 * @throws IOException if the classes file is empty or a prediction file has fewer lines
	 *                     than the testing file
	 * @throws NumberFormatException if a prediction line is not a number
	 */
	public static void generateFinalDecisions(String testing_set, String classes_file) throws Exception {
		String [] classes = readClassNames(classes_file);

		BufferedReader [] in = new BufferedReader [classes.length];
		try {
			for (int i = 0; i < classes.length; i++) {
				in[i] = new BufferedReader(new FileReader(testing_set + "_" + classes[i] + ".out"));
			}
			BufferedReader testing = new BufferedReader(new FileReader(testing_set + ".svm"));
			try {
				BufferedWriter testing_decisions = new BufferedWriter(new FileWriter(testing_set + ".out"));
				try {
					int line_number = 0;
					for (String line = testing.readLine(); line != null; line = testing.readLine()) {
						line_number++;
						double max = Double.NEGATIVE_INFINITY;
						int max_class = 0;
						for (int i = 0; i < classes.length; i++) {
							String prediction = in[i].readLine();
							if (prediction == null) {
								throw new IOException("Prediction file " + testing_set + "_" + classes[i]
									+ ".out has fewer lines than " + testing_set + ".svm (ran out at line "
									+ line_number + ")");
							}
							double act = Double.parseDouble(prediction.trim());
							if (act > max) {
								max = act;
								max_class = i;
							}
						}
						testing_decisions.write(classes[max_class] + "\t" + max + "\n");
					}
				} finally {
					testing_decisions.close();
				}
			} finally {
				testing.close();
			}
		} finally {
			closeAll(in);
		}
	}

	/** Reads the ':'-separated class names from the first line of {@code classes_file}. */
	private static String [] readClassNames(String classes_file) throws IOException {
		BufferedReader in = new BufferedReader(new FileReader(classes_file));
		String line;
		try {
			line = in.readLine();
		} finally {
			in.close();
		}
		if (line == null || isBlank(line)) {
			throw new IOException("No class names found in " + classes_file);
		}
		return line.split(":");
	}

	/**
	 * Prints a confusion matrix (rows: oracle label, columns: system decision) and the
	 * accuracy to standard output. Pairs whose oracle label is blank (e.g. empty lines in
	 * the instance file) are skipped. Labels that appear in the data but not in
	 * {@code classes} get extra rows/columns instead of causing an index error.
	 *
	 * @param oracle           gold labels, one per instance
	 * @param system_decisions predicted labels, aligned with {@code oracle}
	 * @param classes          classes to show first in the matrix; not modified
	 * @throws IllegalArgumentException if the two label lists differ in length
	 */
	public static void computePerformances(Vector<String> oracle, Vector<String> system_decisions, Vector<String> classes) throws Exception {
		System.out.println("--- Computing performances ---");
		if (oracle.size() != system_decisions.size()) {
			throw new IllegalArgumentException("Oracle has " + oracle.size()
				+ " labels but the system produced " + system_decisions.size() + " decisions");
		}

		// Work on a copy so unseen labels can be added without touching the caller's vector.
		Vector<String> labels = new Vector<String>(classes);
		for (int k = 0; k < oracle.size(); k++) {
			if (isBlank(oracle.elementAt(k))) continue;
			if (!labels.contains(oracle.elementAt(k))) labels.add(oracle.elementAt(k));
			if (!labels.contains(system_decisions.elementAt(k))) labels.add(system_decisions.elementAt(k));
		}

		int n = labels.size();
		int [] [] confusion_matrix = new int [n][n];
		int evaluated = 0;
		int good_decisions = 0;
		for (int k = 0; k < oracle.size(); k++) {
			if (isBlank(oracle.elementAt(k))) continue;
			int row = labels.indexOf(oracle.elementAt(k));
			int col = labels.indexOf(system_decisions.elementAt(k));
			confusion_matrix[row][col]++;
			evaluated++;
			if (row == col) good_decisions++;
		}

		System.out.print("\t");
		for (int i = 0; i < n; i++) System.out.print(labels.elementAt(i) + "\t");
		System.out.println();

		for (int i = 0; i < n; i++) {
			System.out.print(labels.elementAt(i) + "\t");
			for (int j = 0; j < n; j++) {
				System.out.print("" + confusion_matrix[i][j] + "\t");
			}
			System.out.println();
		}

		if (evaluated == 0) {
			System.out.println("Accuracy: n/a (no labelled instances)");
		} else {
			double accuracy = ((double) good_decisions) / evaluated;
			System.out.println("Accuracy: " + accuracy * 100 + "%");
		}
	}

	/**
	 * Loads one label per line of {@code instance_set} (a full path, extension included),
	 * taking the text before the label separator. Blank lines yield an empty string so the
	 * result stays aligned line-by-line with the file.
	 *
	 * @param instance_set path of the file to read
	 * @return the labels, one per line
	 */
	public static Vector <String> loadClassifiedElements(String instance_set) throws Exception {
		Vector<String> decisions = new Vector<String>();
		BufferedReader in = new BufferedReader(new FileReader(instance_set));
		try {
			for (String line = in.readLine(); line != null; line = in.readLine()) {
				decisions.add(labelOf(line));
			}
		} finally {
			in.close();
		}
		return decisions;
	}

	/** Returns the label part of an instance/decision line (trimmed). */
	private static String labelOf(String line) {
		String trimmed = line.trim();
		int sep = separatorIndex(trimmed);
		return (sep < 0 ? trimmed : trimmed.substring(0, sep).trim());
	}

	/** Returns the instance part of a line (everything after the separator, trimmed). */
	private static String instanceOf(String line) {
		String trimmed = line.trim();
		int sep = separatorIndex(trimmed);
		return (sep < 0 ? "" : trimmed.substring(sep + 1).trim());
	}

	/** Index of the first tab, else of the first whitespace character, else -1. */
	private static int separatorIndex(String trimmed) {
		int tab = trimmed.indexOf('\t');
		if (tab >= 0) return tab;
		for (int k = 0; k < trimmed.length(); k++) {
			if (Character.isWhitespace(trimmed.charAt(k))) return k;
		}
		return -1;
	}

	/** True when the string has no non-whitespace characters. */
	private static boolean isBlank(String s) {
		return s.trim().length() == 0;
	}

	/** Writes {@code content} to {@code file} with the platform default encoding. */
	private static void writeTextFile(File file, String content) throws IOException {
		BufferedWriter out = new BufferedWriter(new FileWriter(file));
		try {
			out.write(content);
		} finally {
			out.close();
		}
	}

	/**
	 * Closes every non-null resource, even if an earlier close fails; rethrows the first
	 * failure afterwards.
	 */
	private static void closeAll(Closeable... resources) throws IOException {
		IOException first = null;
		for (Closeable resource : resources) {
			if (resource == null) continue;
			try {
				resource.close();
			} catch (IOException e) {
				if (first == null) first = e;
			}
		}
		if (first != null) throw first;
	}

}

